{
 "metadata": {
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
  "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6-final"
  },
  "orig_nbformat": 2,
  "kernelspec": {
   "name": "python3",
   "display_name": "Python 3",
   "language": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2,
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stderr",
     "text": [
      "Using backend: pytorch\n"
     ]
    }
   ],
   "source": [
    "import scipy.io\n",
    "import urllib.request\n",
    "import dgl\n",
    "import math\n",
    "import numpy as np\n",
    "from model import *\n",
    "import torch\n",
    "from data_loader import data_loader\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "4051"
      ]
     },
     "metadata": {},
     "execution_count": 16
    }
   ],
   "source": [
    "# Load the dataset stored under ../DBLP via the project's data_loader.\n",
    "dataset = data_loader('../DBLP')\n",
    "\n",
    "# Map each (src_type, edge_type, dst_type) triple to a (src_ids, dst_ids) tensor pair.\n",
    "edge_dict = {}\n",
    "for link_id, meta_path in dataset.links['meta'].items():\n",
    "    src_type, dst_type = meta_path[0], meta_path[1]\n",
    "    # Convert to COO once and reuse it for both the row and the column indices.\n",
    "    coo = dataset.links['data'][link_id].tocoo()\n",
    "    # Subtract each node type's shift from the matrix indices\n",
    "    # (presumably turning global ids into per-type ids -- confirm against data_loader).\n",
    "    src_ids = torch.tensor(coo.row - dataset.nodes['shift'][src_type])\n",
    "    dst_ids = torch.tensor(coo.col - dataset.nodes['shift'][dst_type])\n",
    "    edge_key = (str(src_type), str(src_type) + '_' + str(dst_type), str(dst_type))\n",
    "    edge_dict[edge_key] = (src_ids, dst_ids)\n",
    "\n",
    "# Largest index flagged in the training-label mask (displayed as the cell output).\n",
    "np.nonzero(dataset.labels_train['mask'])[0].max()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "4057\n0 4057\n1 14328\n2 7723\n3 20\n"
     ]
    }
   ],
   "source": [
    "# Per-node-type counts as provided by the loader.\n",
    "node_counts = dataset.nodes['count']\n",
    "print(node_counts[0])\n",
    "\n",
    "# Echo every (type id, count) pair, then key the counts by the type id as a string.\n",
    "for type_id, n_nodes in node_counts.items():\n",
    "    print(type_id, n_nodes)\n",
    "node_dict = {str(type_id): n_nodes for type_id, n_nodes in node_counts.items()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 71,
   "metadata": {},
   "outputs": [
    {
     "output_type": "error",
     "ename": "KeyError",
     "evalue": "'0'",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyError\u001b[0m                                  Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-71-a3dcc0ec2ba0>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mG\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdgl\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mheterograph\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0medge_dict\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnum_nodes_dict\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnode_dict\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;32m~/anaconda3/envs/fwz_py3/lib/python3.7/site-packages/dgl/convert.py\u001b[0m in \u001b[0;36mheterograph\u001b[0;34m(data_dict, num_nodes_dict, idtype, device)\u001b[0m\n\u001b[1;32m    313\u001b[0m             \u001b[0mnum_nodes_dict\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mdty\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnum_nodes_dict\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mdty\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvrange\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    314\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m  \u001b[0;31m# sanity check\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 315\u001b[0;31m             \u001b[0;32mif\u001b[0m \u001b[0mnum_nodes_dict\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0msty\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m<\u001b[0m \u001b[0murange\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    316\u001b[0m                 raise DGLError('The given number of nodes of node type {} must be larger than'\n\u001b[1;32m    317\u001b[0m                                ' the max ID in the data, but got {} and {}.'.format(\n",
      "\u001b[0;31mKeyError\u001b[0m: '0'"
     ]
    }
   ],
   "source": [
    "# Assemble the heterogeneous graph from the per-relation edge tensors,\n",
    "# passing an explicit node count for each node type.\n",
    "G = dgl.heterograph(\n",
    "    edge_dict,\n",
    "    num_nodes_dict=node_dict,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Assign consecutive integer ids to the graph's node types and edge types.\n",
    "# These maps get their own names: node_dict / edge_dict are the inputs of the\n",
    "# dgl.heterograph(...) call, and rebinding them here would break re-running it.\n",
    "node_type_ids = {}\n",
    "for ntype in G.ntypes:\n",
    "    node_type_ids[ntype] = len(node_type_ids)\n",
    "\n",
    "edge_type_ids = {}\n",
    "for etype in G.etypes:\n",
    "    edge_type_ids[etype] = len(edge_type_ids)\n",
    "    # Tag every edge of this type with its type id, stored as a long tensor.\n",
    "    G.edges[etype].data['id'] = torch.full(\n",
    "        (G.number_of_edges(etype),), edge_type_ids[etype], dtype=torch.long\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 68,
   "metadata": {},
   "outputs": [
    {
     "output_type": "error",
     "ename": "DGLError",
     "evalue": "Node type \"ntype\" does not exist.",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mDGLError\u001b[0m                                  Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-68-3033e4d996e4>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      2\u001b[0m     \u001b[0memb\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mParameter\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mG\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnumber_of_nodes\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mntype\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m256\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrequires_grad\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      3\u001b[0m     \u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minit\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mxavier_uniform_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0memb\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m     \u001b[0mG\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnodes\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'ntype'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'inp'\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0memb\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;32m~/anaconda3/envs/fwz_py3/lib/python3.7/site-packages/dgl/view.py\u001b[0m in \u001b[0;36m__getitem__\u001b[0;34m(self, key)\u001b[0m\n\u001b[1;32m     35\u001b[0m             \u001b[0mnodes\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mkey\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     36\u001b[0m             \u001b[0mntype\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 37\u001b[0;31m         \u001b[0mntid\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_typeid_getter\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mntype\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     38\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0mNodeSpace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mHeteroNodeDataView\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mntype\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mntid\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnodes\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     39\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/envs/fwz_py3/lib/python3.7/site-packages/dgl/heterograph.py\u001b[0m in \u001b[0;36mget_ntype_id\u001b[0;34m(self, ntype)\u001b[0m\n\u001b[1;32m   1075\u001b[0m         \u001b[0mntid\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_srctypes_invmap\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mntype\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_dsttypes_invmap\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mntype\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1076\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mntid\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1077\u001b[0;31m             \u001b[0;32mraise\u001b[0m \u001b[0mDGLError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'Node type \"{}\" does not exist.'\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mntype\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1078\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0mntid\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1079\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mDGLError\u001b[0m: Node type \"ntype\" does not exist."
     ]
    }
   ],
   "source": [
    "# Width of the randomly initialised input feature vectors.\n",
    "INPUT_DIM = 256\n",
    "\n",
    "# Give every node type a non-trainable, Xavier-initialised feature matrix.\n",
    "# torch.nn is referenced explicitly instead of relying on `nn` leaking in\n",
    "# through `from model import *`.\n",
    "for ntype in G.ntypes:\n",
    "    emb = torch.nn.Parameter(\n",
    "        torch.empty(G.number_of_nodes(ntype), INPUT_DIM), requires_grad=False\n",
    "    )\n",
    "    torch.nn.init.xavier_uniform_(emb)\n",
    "    G.nodes[ntype].data['inp'] = emb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the graph's repr as the cell output; the previous `G.` was an\n",
    "# incomplete expression and raised a SyntaxError on execution.\n",
    "G"
   ]
  }
 ]
}